import math
import torch
import numpy as np
from typing import TypeVar, Optional, Iterator
import torch.distributed as dist
from torch.utils.data import Sampler, Dataset


# Covariant type variable used in the ``Iterator[T_co]`` return annotations of the samplers' ``__iter__`` methods.
T_co = TypeVar('T_co', covariant=True)

import math
from typing import Sized
from torch.utils.data.sampler import Sampler


class BalancedSampler(Sampler[int]):
    """Class-balanced sampler that emits one index per class in round-robin order.

    Each epoch draws ``len(data_source) // args.num_labels`` indices from every
    class. The indices are interleaved as class 0, class 1, ..., class 0, ...,
    so each run of ``num_labels`` consecutive indices holds one sample per class.
    A class with more samples than needed is sub-sampled without replacement.
    A smaller class is covered by concatenating several independent random
    permutations of its members.

    Args:
        data_source: dataset exposing ``targets`` (one label per sample) and
            ``__len__``.
        args: object with a ``num_labels`` attribute (number of classes).

    Raises:
        ValueError: if ``args`` is None, or the number of distinct target
            values differs from ``args.num_labels``.
    """

    data_source: Sized

    def __init__(self, data_source: Sized, args=None) -> None:
        if args is None:
            raise ValueError(
                "BalancedSampler requires `args` with a `num_labels` attribute")
        self.dt = data_source
        self.args = args
        self.n_cls = args.num_labels
        self.n_dt = len(self.dt)
        # Number of indices drawn from each class per epoch.
        self.n_cls_wise_desired = self.n_dt // self.n_cls

        # Sorting groups the samples by label: self.inds[k] is the dataset index
        # found at sorted position k.
        targets = torch.as_tensor(self.dt.targets)
        vals, self.inds = torch.sort(targets)
        # Real class sizes in sorted-label order, instead of assuming balance.
        _, counts = torch.unique_consecutive(vals, return_counts=True)
        if len(counts) != self.n_cls:
            raise ValueError(
                f"expected {self.n_cls} distinct target values, "
                f"found {len(counts)}")
        self.n_per_cls = counts.tolist()
        # Independent permutations needed so a class of size n can supply
        # n_cls_wise_desired indices (ceil division).
        self.n_repeat = [-(-self.n_cls_wise_desired // n) for n in self.n_per_cls]
        # Position in the sorted order where each class starts.
        self.st_idx_cls = np.insert(np.cumsum(self.n_per_cls), 0, 0)[:-1]
        self.cls_idx = torch.from_numpy(self.st_idx_cls).\
            unsqueeze(1).expand(self.n_cls, int(self.n_cls_wise_desired))
        # get_b() yields exactly this many indices, so __len__ must report it.
        self.n_samples = self.n_cls * self.n_cls_wise_desired

    def __len__(self):
        return self.n_samples

    def num_samples(self) -> int:
        return self.n_samples

    def __iter__(self):
        return iter(self.get_b())

    def get_b(self):
        """Return one epoch of dataset indices, interleaved across classes."""
        per_cls = []
        for n_rep, cls_size in zip(self.n_repeat, self.n_per_cls):
            # Each row of the argsort is an independent random permutation of
            # range(cls_size); flattening concatenates them.
            rand = torch.rand(n_rep, cls_size)
            per_cls.append(rand.argsort(dim=-1).reshape(-1)[:self.n_cls_wise_desired])
        # Shift the within-class offsets to positions in the sorted order.
        positions = torch.stack(per_cls, 0) + self.cls_idx
        # Transposing before flattening interleaves the classes.
        b = positions.permute(1, 0).reshape(-1).tolist()
        return self.inds[b].tolist()


class SemiSupervisedSampler(Sampler): # not for ddp
    """Balanced sampling from the labeled and unlabeled data.

    Yields batches as lists of indices, so it is meant to be used as a batch
    sampler. Each batch contains up to ``batch_size - int(batch_size *
    unsup_fraction)`` labeled indices, taken in order from a fresh random
    permutation of ``sup_inds`` on every pass. The batch is then filled up to
    ``batch_size`` with unlabeled indices drawn uniformly with replacement from
    ``unsup_inds``, and shuffled. Not intended for distributed training.

    Args:
        sup_inds: labeled dataset indices.
        unsup_inds: unlabeled dataset indices.
        batch_size: total indices per batch.
        unsup_fraction: share of each batch reserved for unlabeled indices.
            ``None`` or a negative value merges ``unsup_inds`` into the labeled
            pool and uses no unlabeled slots.
        num_batches: batches per iteration. Defaults to one pass over the
            labeled pool.

    Raises:
        ValueError: if no labeled slot is left per batch, if unlabeled slots are
            requested but ``unsup_inds`` is empty, or if batches are requested
            from an empty labeled pool. Without these checks, iteration would
            fail late or loop forever.
    """
    # NOTE(review): is a 1:1 labeled/unlabeled split (the default
    # unsup_fraction=0.5) the right balance?
    def __init__(self, sup_inds, unsup_inds, batch_size, unsup_fraction=0.5,
                 num_batches=None):
        if unsup_fraction is None or unsup_fraction < 0:
            # list() keeps this a concatenation even for array-like inputs.
            self.sup_inds = list(sup_inds) + list(unsup_inds)
            self.unsup_inds = []
            unsup_fraction = 0.0
        else:
            self.sup_inds = sup_inds
            self.unsup_inds = unsup_inds

        self.batch_size = batch_size
        unsup_batch_size = int(batch_size * unsup_fraction)
        self.sup_batch_size = batch_size - unsup_batch_size
        if self.sup_batch_size < 1:
            raise ValueError(
                "unsup_fraction leaves no room for labeled samples "
                f"(batch_size={batch_size}, unsup_fraction={unsup_fraction})")
        if self.sup_batch_size < self.batch_size and len(self.unsup_inds) == 0:
            raise ValueError("unsup_inds is empty but unsup_fraction > 0")

        if num_batches is not None:
            self.num_batches = num_batches
        else:
            self.num_batches = int(
                np.ceil(len(self.sup_inds) / self.sup_batch_size))
        if self.num_batches > 0 and len(self.sup_inds) == 0:
            # The loop in __iter__ would never produce a batch and never exit.
            raise ValueError("cannot draw batches: sup_inds is empty")

        super().__init__(None)


    def __iter__(self):
        batch_counter = 0
        # Keep making passes over the labeled pool until num_batches is reached.
        while batch_counter < self.num_batches:
            sup_inds_shuffled = [self.sup_inds[i]
                                 for i in torch.randperm(len(self.sup_inds))]
            for sup_k in range(0, len(self.sup_inds), self.sup_batch_size):
                if batch_counter == self.num_batches:
                    break
                batch = sup_inds_shuffled[sup_k:(sup_k + self.sup_batch_size)]
                if self.sup_batch_size < self.batch_size:
                    # Top up with unlabeled indices (with replacement); a short
                    # final labeled chunk gets correspondingly more of them.
                    batch.extend([self.unsup_inds[i] for i in
                                  torch.randint(high=len(self.unsup_inds),
                                                size=(
                                                    self.batch_size - len(
                                                        batch),),
                                                dtype=torch.int64)])
                # this shuffle operation is very important, without it
                # batch-norm / DataParallel hell ensues
                np.random.shuffle(batch)
                yield batch
                batch_counter += 1

    def __len__(self):
        return self.num_batches


class DistributedSampler(Sampler):
    """Restrict sampling to the shard of a dataset that belongs to one rank.

    Every rank builds the same global index order. When ``shuffle`` is set, the
    order is a permutation seeded with ``seed + epoch``. The order is then made
    a multiple of ``num_replicas``: it is padded by wrapping around, or
    truncated when ``drop_last`` is set. Each rank keeps every
    ``num_replicas``-th index, starting at its own ``rank``.

    Args:
        dataset: sized dataset to sample from.
        num_replicas: number of participating processes. Defaults to the
            ``torch.distributed`` world size.
        rank: rank of this process. Defaults to the ``torch.distributed`` rank.
        shuffle: whether to shuffle the global order each epoch.
        seed: base shuffling seed. It must be the same on every rank, or the
            shards will not be disjoint.
        drop_last: truncate instead of pad when ``len(dataset)`` is not a
            multiple of ``num_replicas``.

    Raises:
        RuntimeError: if a default must come from ``torch.distributed`` but the
            package is unavailable.
        ValueError: if ``rank`` is outside ``[0, num_replicas - 1]``.
    """

    @staticmethod
    def _require_distributed() -> None:
        """Fail early when torch.distributed cannot supply rank/world size."""
        if not dist.is_available():
            raise RuntimeError("Requires distributed package to be available")

    def __init__(self, dataset: Dataset, num_replicas: Optional[int] = None,
                 rank: Optional[int] = None, shuffle: bool = True,
                 seed: int = 0, drop_last: bool = False) -> None:
        if num_replicas is None:
            self._require_distributed()
            num_replicas = dist.get_world_size()
        if rank is None:
            self._require_distributed()
            rank = dist.get_rank()
        if not 0 <= rank < num_replicas:
            raise ValueError(
                f"Invalid rank {rank}, rank should be in the interval"
                f" [0, {num_replicas - 1}]")

        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        self.drop_last = drop_last
        self.shuffle = shuffle
        self.seed = seed

        n_items = len(self.dataset)  # type: ignore[arg-type]
        uneven = n_items % num_replicas != 0
        if drop_last and uneven:
            # Largest per-rank share that needs no padding (floor division).
            self.num_samples = math.ceil((n_items - num_replicas) / num_replicas)
        else:
            # Per-rank share after padding up to a multiple of num_replicas.
            self.num_samples = math.ceil(n_items / num_replicas)
        self.total_size = self.num_samples * num_replicas

    def __iter__(self) -> Iterator[T_co]:
        n_items = len(self.dataset)  # type: ignore[arg-type]
        if not self.shuffle:
            order = list(range(n_items))
        else:
            # Same seed on every rank -> same global order -> disjoint shards.
            gen = torch.Generator()
            gen.manual_seed(self.seed + self.epoch)
            order = torch.randperm(n_items, generator=gen).tolist()

        if self.drop_last:
            order = order[:self.total_size]
        else:
            missing = self.total_size - len(order)
            if missing > 0:
                # Wrap around, repeating the order as often as needed.
                reps = math.ceil(missing / len(order))
                order.extend((order * reps)[:missing])
        assert len(order) == self.total_size

        shard = order[self.rank:self.total_size:self.num_replicas]
        assert len(shard) == self.num_samples
        return iter(shard)

    def __len__(self) -> int:
        return self.num_samples

    def set_epoch(self, epoch: int) -> None:
        r"""
        Record the epoch used to seed the next shuffle (``seed + epoch``).

        With ``shuffle=True``, calling this once per epoch gives all replicas a
        new, shared ordering. Without it, every iteration repeats the same
        ordering.

        Args:
            epoch (int): Epoch number.
        """
        self.epoch = epoch


class SemiSupDistributedSampler(Sampler):
    """Shared setup for distributed samplers that mix labeled and unlabeled data.

    The epoch is sized from the labeled set:
    ``dt_len = ceil(n_sup / (1 - unsup_fraction))`` positions, of which
    ``n_unsup = dt_len - n_sup`` are reserved for unlabeled indices. The epoch
    is split evenly over ``num_replicas`` ranks, so each rank receives
    ``n_sample_per_gpu = dt_len // num_replicas`` indices. Subclasses
    implement ``__iter__``.

    Args:
        sup_inds: labeled dataset indices.
        unsup_inds: unlabeled dataset indices.
        bsz_per_gpu: per-rank batch size. It is split into ``sup_bsz_per_gpu``
            labeled and ``unsup_bsz_per_gpu`` unlabeled slots.
        unsup_fraction: share of unlabeled samples. ``None`` or a negative value
            merges ``unsup_inds`` into the labeled pool and uses a share of 0.
        num_batches: explicit batch count. When ``None``, it is computed from
            the labeled set and the global labeled batch size.
        num_replicas: world size. Defaults to ``torch.distributed``'s.
        rank: rank of this process. Defaults to ``torch.distributed``'s.
        shuffle: stored only. NOTE(review): neither SSDRatioSampler nor
            SSDFullSampler reads it.
        seed: base seed. The subclasses seed their generator with
            ``seed + epoch``.
        drop_last: stored only. The epoch length must already be divisible by
            ``num_replicas``, so nothing is ever dropped or padded.

    Raises:
        RuntimeError: if a default must come from ``torch.distributed`` but the
            package is unavailable.
        ValueError: for an invalid rank, for ``unsup_fraction >= 1``, or when
            ``dt_len`` is not divisible by ``num_replicas``.
    """

    def __init__(self, sup_inds, unsup_inds, bsz_per_gpu,
                 unsup_fraction=0.5,
                 num_batches=None,
                 num_replicas: Optional[int] = None,
                 rank: Optional[int] = None, shuffle: bool = True,
                 seed: int = 0, drop_last: bool = False) -> None:
        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = dist.get_rank()
        if rank >= num_replicas or rank < 0:
            raise ValueError(
                "Invalid rank {}, rank should be in the interval"
                " [0, {}]".format(rank, num_replicas - 1))
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        self.drop_last = drop_last
        self.shuffle = shuffle
        self.seed = seed

        # Normalize the fraction BEFORE it is used in any arithmetic below.
        if unsup_fraction is None or unsup_fraction < 0:
            # list() keeps this a concatenation even for array-like inputs.
            self.sup_inds = list(sup_inds) + list(unsup_inds)
            self.unsup_inds = []
            unsup_fraction = 0.0
        else:
            self.sup_inds = sup_inds
            self.unsup_inds = unsup_inds
        if unsup_fraction >= 1:
            raise ValueError(
                f"unsup_fraction must be < 1, got {unsup_fraction}")

        self.n_sup = len(self.sup_inds)
        sup_fraction = 1 - unsup_fraction
        self.dt_len = math.ceil(self.n_sup / sup_fraction)
        self.n_unsup = self.dt_len - self.n_sup
        self.unsup_fraction = unsup_fraction
        self.bsz_per_gpu = bsz_per_gpu

        # Each rank gets the same number of samples; there is no padding or
        # truncation, so the epoch length must split evenly.
        if self.dt_len % self.num_replicas != 0:
            raise ValueError(
                f"epoch length {self.dt_len} is not divisible by "
                f"num_replicas={self.num_replicas}")
        self.n_sample_per_gpu = self.dt_len // self.num_replicas  # samples per GPU
        self.total_size = self.n_sample_per_gpu * self.num_replicas

        self.unsup_bsz_per_gpu = int(bsz_per_gpu * unsup_fraction)
        self.sup_bsz_per_gpu = bsz_per_gpu - self.unsup_bsz_per_gpu
        if num_batches is not None:
            self.num_batches = num_batches
        else:
            tmp = self.n_sup / (self.sup_bsz_per_gpu * self.num_replicas)
            self.num_batches = int(np.ceil(tmp))
        super().__init__(None)

    def __len__(self) -> int:
        return self.n_sample_per_gpu

    def set_epoch(self, epoch: int) -> None:
        r"""
        Sets the epoch for this sampler. Subclasses seed their generator with
        ``seed + epoch``, so calling this every epoch gives all replicas a new,
        shared ordering; otherwise the next iteration repeats the same one.

        Args:
            epoch (int): Epoch number.
        """
        self.epoch = epoch


class SSDRatioSampler(SemiSupDistributedSampler):
    """Distributed sampler that keeps a fixed labeled/unlabeled ratio.

    Each epoch uses every labeled index exactly once and draws ``n_unsup``
    unlabeled indices uniformly with replacement. Labeled indices are laid out
    as a ``(sup_bsz_per_gpu, -1)`` grid and unlabeled ones as a
    ``(unsup_bsz_per_gpu, -1)`` grid. The grids are stacked and read column by
    column, so each aligned run of ``bsz_per_gpu`` global positions holds
    ``sup_bsz_per_gpu`` labeled and ``unsup_bsz_per_gpu`` unlabeled indices.
    Each rank then keeps every ``num_replicas``-th position, starting at
    ``rank``.

    NOTE(review): the reshapes require ``n_sup`` to be divisible by
    ``sup_bsz_per_gpu``, both grids to have the same number of columns, and
    ``unsup_bsz_per_gpu > 0``. Otherwise numpy raises.
    """

    def __iter__(self) -> Iterator[T_co]:
        # All randomness comes from `g`, seeded identically on every rank, so
        # all ranks build the same global order and their shards are disjoint.
        g = torch.Generator()
        g.manual_seed(self.seed + self.epoch)
        sup_order = torch.randperm(self.n_sup, generator=g).tolist()
        sup_inds_shuffled = [self.sup_inds[i] for i in sup_order]
        unsup_draws = torch.randint(high=len(self.unsup_inds),
                                    size=(self.n_unsup,), dtype=torch.int64,
                                    generator=g).tolist()
        unsup_inds_sampled = [self.unsup_inds[i] for i in unsup_draws]
        sup_inds_np = np.array(sup_inds_shuffled).reshape(self.sup_bsz_per_gpu, -1)
        unsup_inds_np = np.array(unsup_inds_sampled).reshape(self.unsup_bsz_per_gpu, -1)
        # Permute rows with the seeded generator. The previous
        # np.random.shuffle used the global numpy RNG, which is not tied to
        # seed/epoch and may differ between ranks.
        sup_inds_np = sup_inds_np[
            torch.randperm(sup_inds_np.shape[0], generator=g).numpy()]
        unsup_inds_np = unsup_inds_np[
            torch.randperm(unsup_inds_np.shape[0], generator=g).numpy()]
        # Transposing interleaves: each row of the result is one block of
        # sup_bsz_per_gpu labeled followed by unsup_bsz_per_gpu unlabeled indices.
        indices = np.concatenate([sup_inds_np, unsup_inds_np], axis=0).T.flatten()
        # No padding/truncation: the base class guarantees
        # total_size == dt_len == n_sup + n_unsup.
        assert len(indices) == self.total_size
        # subsample
        indices = indices[self.rank:self.total_size:self.num_replicas]
        assert len(indices) == self.n_sample_per_gpu
        return iter(indices.tolist())


class SSDFullSampler(SemiSupDistributedSampler):
    """Distributed sampler over the union of labeled and unlabeled indices.

    Every rank forms ``sup_inds + unsup_inds`` and permutes it with a generator
    seeded by ``seed + epoch``. It then keeps every ``num_replicas``-th entry,
    starting at ``rank``, among the first ``total_size`` entries. Entries past
    ``total_size`` are skipped for that epoch.

    NOTE(review): the length assertion only holds when the union has at least
    ``total_size`` entries.
    """

    def __iter__(self) -> Iterator[T_co]:
        g = torch.Generator()
        g.manual_seed(self.seed + self.epoch)
        # list() keeps `+` a concatenation even for array-like index inputs.
        pool = list(self.sup_inds) + list(self.unsup_inds)
        order = torch.randperm(len(pool), generator=g).tolist()
        # No extra np.random.shuffle: it is redundant after randperm, and the
        # global numpy RNG is not tied to seed/epoch, so ranks could disagree
        # on the order and their shards would overlap.
        shuffled = [pool[i] for i in order]
        indices = shuffled[self.rank:self.total_size:self.num_replicas]
        assert len(indices) == self.n_sample_per_gpu
        return iter(indices)




